import openai
import os
import torch
import open_clip
from PIL import Image
from pycocotools.coco import COCO


def generate_description_openai(image_caption, clip_features):
    """Request a new image description from the OpenAI chat model.

    The user prompt quotes the existing caption and interpolates the string
    form of ``clip_features`` directly into the text.

    Args:
        image_caption: Caption text to quote in the prompt.
        clip_features: Embedding value; it is formatted into the prompt via
            an f-string, so whatever ``str()`` yields is what the model sees.

    Returns:
        The ``content`` field of the first choice's message in the response.
    """
    caption_part = f"Here is an image with the caption: '{image_caption}'. "
    embedding_part = (
        "Based on this caption and the visual features represented by this "
        f"embedding '{clip_features}', please generate a new detailed description."
    )
    user_prompt = caption_part + embedding_part

    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant that generates captions for images."},
        {"role": "user", "content": user_prompt},
    ]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
    )

    # Only the first returned choice is used.
    first_choice = completion['choices'][0]
    return first_choice['message']['content']


# Prefer a key supplied through the environment so real credentials never
# have to be committed; the original literal is kept as the fallback value.
openai.api_key = os.environ.get('OPENAI_API_KEY', 'keys')

# COCO 2014 training captions and the directory holding the matching images.
coco_annotation_file = './data/datasets/coco/annotations/captions_train2014.json'
coco_data_dir = './data/datasets/coco/train2014'
coco = COCO(coco_annotation_file)
image_ids = coco.getImgIds()
images = coco.loadImgs(image_ids)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32',
                                                             pretrained='./data/model/clip/mineclip/vitB/coco_5C_3_100.pt')
# Image tensors are moved to `device` before encoding, so the model has to
# live on the same device or encode_image fails with a device mismatch
# whenever CUDA is available.
model = model.to(device)
# Inference only: switch off training-mode behaviour (open_clip documents
# that models come back in train mode by default).
model.eval()

if __name__ == '__main__':
    # For every COCO image: encode it with CLIP, take its first caption, ask
    # the chat model for a new description, then write all results to disk.
    descriptions = {}
    for img_info in images:
        img_id = img_info['id']
        img_filename = os.path.join(coco_data_dir, img_info['file_name'])
        # The context manager releases the file handle once the image has
        # been preprocessed, instead of leaking one handle per image.
        with Image.open(img_filename) as pil_image:
            image_tensor = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image_tensor).cpu().numpy().flatten()
        # imgToAnns maps an image id to a list of annotation dicts (not ids),
        # so the caption is read straight from the first entry; looking it up
        # again through coco.anns would raise TypeError (unhashable dict).
        captions = coco.imgToAnns[img_id]
        if captions:
            first_caption = captions[0]['caption']
        else:
            first_caption = "No caption available."
        new_description = generate_description_openai(first_caption, image_features)
        descriptions[img_info['file_name']] = new_description

    output_path = "./coco/mine/text/coco_image_descriptions.txt"
    # Make sure the target directory exists so the results of a long run are
    # not lost to a FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Explicit UTF-8: generated text may contain characters outside the
    # platform's default encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        for image_name, description in descriptions.items():
            f.write(f"Image: {image_name}\nDescription: {description}\n\n")